{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 01. Block Poisson problem.\n",
    "\n",
    "In this tutorial we first solve the problem\n",
    "\n",
    "\\begin{cases}\n",
    "-u'' = f & \\text{in }\\Omega = (0, 1),\\\\\n",
    " u   = 0 & \\text{on }\\Gamma = \\{0, 1\\}\n",
    "\\end{cases}\n",
    "\n",
    "using non blocked FEniCSx code.\n",
    "\n",
    "Then we use block support of FEniCSx to solve the system\n",
    "\n",
    "\\begin{cases}\n",
    "- w_1'' - 2 w_2'' = 3 f & \\text{in }\\Omega,\\\\\n",
    "- 3 w_1'' - 4 w_2'' = 7 f & \\text{in }\\Omega\n",
    "\\end{cases}\n",
    "\n",
    "subject to\n",
    "\n",
    "\\begin{cases}\n",
    " w_1 = 0 & \\text{on }\\Gamma,\\\\\n",
    " w_2 = 0 & \\text{on }\\Gamma\n",
    "\\end{cases}\n",
    "\n",
    "By construction the solution of the system is\n",
    "$$(w_1, w_2) = (u, u)$$\n",
    "\n",
    "We then compare the obtained solutions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from mpi4py import MPI\n",
    "from petsc4py import PETSc\n",
    "from ufl import dx, grad, inner, sin, SpatialCoordinate, TestFunction, TrialFunction\n",
    "from dolfinx import DirichletBC, Function, FunctionSpace, UnitIntervalMesh\n",
    "from dolfinx.fem import (assemble_matrix_block, assemble_scalar, assemble_vector_block,\n",
    "                         create_vector_block, LinearProblem, locate_dofs_topological)\n",
    "from dolfinx.mesh import locate_entities_boundary\n",
    "from dolfinx.plot import create_vtk_topology\n",
    "from multiphenicsx.fem import BlockVecSubVectorWrapper\n",
    "import pyvista"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mesh generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mesh = UnitIntervalMesh(MPI.COMM_WORLD, 32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def left(x):\n",
    "    return abs(x[0] - 0.) < np.finfo(float).eps\n",
    "\n",
    "\n",
    "def right(x):\n",
    "    return abs(x[0] - 1.) < np.finfo(float).eps\n",
    "\n",
    "\n",
    "left_facets = locate_entities_boundary(mesh, mesh.topology.dim - 1, left)\n",
    "right_facets = locate_entities_boundary(mesh, mesh.topology.dim - 1, right)\n",
    "boundary_facets = np.hstack((left_facets, right_facets))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x0 = SpatialCoordinate(mesh)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dolfinx_to_pyvista_mesh(mesh):\n",
    "    num_cells = mesh.topology.index_map(mesh.topology.dim).size_local\n",
    "    cell_entities = np.arange(num_cells, dtype=np.int32)\n",
    "    pyvista_cells, cell_types = create_vtk_topology(mesh, mesh.topology.dim, cell_entities)\n",
    "    grid = pyvista.UnstructuredGrid(pyvista_cells, cell_types, mesh.geometry.x)\n",
    "    return grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pyvista_mesh_plot(mesh):\n",
    "    grid = dolfinx_to_pyvista_mesh(mesh)\n",
    "    plotter = pyvista.PlotterITK()\n",
    "    plotter.add_mesh(grid)\n",
    "    plotter.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pyvista_mesh_plot(mesh)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Standard case (solve for $u$ with FEniCSx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_standard():\n",
    "    # Define a function space\n",
    "    V = FunctionSpace(mesh, (\"Lagrange\", 2))\n",
    "    u = TrialFunction(V)\n",
    "    v = TestFunction(V)\n",
    "\n",
    "    # Define problems forms\n",
    "    a = inner(grad(u), grad(v)) * dx\n",
    "    f = 100 * sin(20 * x0) * v * dx\n",
    "\n",
    "    # Define boundary conditions\n",
    "    zero = Function(V)\n",
    "    with zero.vector.localForm() as zero_local:\n",
    "        zero_local.set(0.0)\n",
    "    bdofs_V = locate_dofs_topological(V, mesh.topology.dim - 1, boundary_facets)\n",
    "    bc = DirichletBC(zero, bdofs_V)\n",
    "\n",
    "    # Solve the linear system\n",
    "    u = Function(V)\n",
    "    problem = LinearProblem(\n",
    "        a, f, bcs=[bc], u=u,\n",
    "        petsc_options={\"ksp_type\": \"preonly\", \"pc_type\": \"lu\", \"pc_factor_mat_solver_type\": \"mumps\"})\n",
    "    problem.solve()\n",
    "    u.vector.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)\n",
    "\n",
    "    # Return the solution\n",
    "    return u"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "u = run_standard()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pyvista_scalar_field_plot(mesh, scalar_field, name):\n",
    "    grid = dolfinx_to_pyvista_mesh(mesh)\n",
    "    grid.point_arrays[name] = scalar_field.compute_point_values()\n",
    "    grid.set_active_scalars(name)\n",
    "    warped = grid.warp_by_scalar()\n",
    "    plotter = pyvista.PlotterITK()\n",
    "    plotter.add_mesh(warped)\n",
    "    plotter.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pyvista_scalar_field_plot(mesh, u, \"u\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Standard block case (solve for ($w_1$, $w_2$) with FEniCSx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_standard_block():\n",
    "    # Define a block function space\n",
    "    V1 = FunctionSpace(mesh, (\"Lagrange\", 2))\n",
    "    V2 = FunctionSpace(mesh, (\"Lagrange\", 2))\n",
    "    (u1, u2) = (TrialFunction(V1), TrialFunction(V2))\n",
    "    (v1, v2) = (TestFunction(V1), TestFunction(V2))\n",
    "\n",
    "    # Define problem block forms\n",
    "    aa = [[1 * inner(grad(u1), grad(v1)) * dx, 2 * inner(grad(u2), grad(v1)) * dx],\n",
    "          [3 * inner(grad(u1), grad(v2)) * dx, 4 * inner(grad(u2), grad(v2)) * dx]]\n",
    "    ff = [300 * sin(20 * x0) * v1 * dx,\n",
    "          700 * sin(20 * x0) * v2 * dx]\n",
    "\n",
    "    # Define block boundary conditions\n",
    "    zero = Function(V1)\n",
    "    with zero.vector.localForm() as zero_local:\n",
    "        zero_local.set(0.0)\n",
    "    bdofs_V1 = locate_dofs_topological((V1, V1), mesh.topology.dim - 1, boundary_facets)\n",
    "    bdofs_V2 = locate_dofs_topological((V2, V1), mesh.topology.dim - 1, boundary_facets)\n",
    "    bc1 = DirichletBC(zero, bdofs_V1, V1)\n",
    "    bc2 = DirichletBC(zero, bdofs_V2, V2)\n",
    "    bcs = [bc1, bc2]\n",
    "\n",
    "    # Assemble the block linear system\n",
    "    AA = assemble_matrix_block(aa, bcs)\n",
    "    AA.assemble()\n",
    "    FF = assemble_vector_block(ff, aa, bcs)\n",
    "\n",
    "    # Solve the block linear system\n",
    "    uu = create_vector_block(ff)\n",
    "    ksp = PETSc.KSP()\n",
    "    ksp.create(mesh.mpi_comm())\n",
    "    ksp.setOperators(AA)\n",
    "    ksp.setType(\"preonly\")\n",
    "    ksp.getPC().setType(\"lu\")\n",
    "    ksp.getPC().setFactorSolverType(\"mumps\")\n",
    "    ksp.setFromOptions()\n",
    "    ksp.solve(FF, uu)\n",
    "    uu.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)\n",
    "\n",
    "    # Split the block solution in components\n",
    "    u1u2 = (Function(V1), Function(V2))\n",
    "    with BlockVecSubVectorWrapper(uu, [V1.dofmap, V2.dofmap]) as uu_wrapper:\n",
    "        for u_wrapper_local, component in zip(uu_wrapper, u1u2):\n",
    "            with component.vector.localForm() as component_local:\n",
    "                component_local[:] = u_wrapper_local\n",
    "\n",
    "    # Return the block solution components\n",
    "    return u1u2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "u1, u2 = run_standard_block()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pyvista_scalar_field_plot(mesh, u1, \"u1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pyvista_scalar_field_plot(mesh, u2, \"u2\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Error computation between the different cases"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_errors(u1, u2, uu1, uu2):\n",
    "    u_1_norm = np.sqrt(mesh.mpi_comm().allreduce(assemble_scalar(inner(grad(u1), grad(u1)) * dx), op=MPI.SUM))\n",
    "    u_2_norm = np.sqrt(mesh.mpi_comm().allreduce(assemble_scalar(inner(grad(u2), grad(u2)) * dx), op=MPI.SUM))\n",
    "    err_1_norm = np.sqrt(\n",
    "        mesh.mpi_comm().allreduce(assemble_scalar(inner(grad(u1 - uu1), grad(u1 - uu1)) * dx), op=MPI.SUM))\n",
    "    err_2_norm = np.sqrt(\n",
    "        mesh.mpi_comm().allreduce(assemble_scalar(inner(grad(u2 - uu2), grad(u2 - uu2)) * dx), op=MPI.SUM))\n",
    "    print(\"  Relative error for first component is equal to\", err_1_norm / u_1_norm)\n",
    "    print(\"  Relative error for second component is equal to\", err_2_norm / u_2_norm)\n",
    "    assert np.isclose(err_1_norm / u_1_norm, 0., atol=1.e-10)\n",
    "    assert np.isclose(err_2_norm / u_2_norm, 0., atol=1.e-10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Computing errors between standard and standard block\")\n",
    "compute_errors(u, u, u1, u2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython"
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
